package com.wang.transfer.util.algorithm;

import java.util.*;

/**
 * Four Sum.
 *
 * <p>Given an integer array {@code nums} and a target value {@code target}, return every
 * quadruplet of elements whose sum equals {@code target}. The result must not contain
 * duplicate quadruplets.
 */
public class FourSum {

    /**
     * Runs a few sample inputs through {@link #fourSum2(int[], int)} and prints the results.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        // Sample inputs paired index-by-index with targets; the second case exercises int overflow.
        int[][] samples = {
                {-1, 0, -5, -2, -2, -4, 0, 1, -2, 2},
                {1000000000, 1000000000, 1000000000, 1000000000},
                {5, -2, -2, 5, 0, 0, -1, 2}
        };
        int[] targets = {-9, -294967296, 8};
        FourSum solver = new FourSum();
        for (int c = 0; c < samples.length; c++) {
            System.out.println(solver.fourSum2(samples[c], targets[c]));
        }
    }

    /**
     * Finds all unique quadruplets in {@code nums} whose sum equals {@code target}.
     *
     * <p>First approach: sort the array, fix the first two elements with nested loops, then
     * narrow in on the remaining pair with two pointers. Duplicate quadruplets are removed by
     * collecting them in a {@link HashSet}, so the order of the returned list is unspecified.
     *
     * <p>All sums are computed in {@code long} arithmetic because adding {@code int} values can
     * overflow and wrap around, which would produce false matches and wrong pointer moves.
     *
     * <p>Side effect: {@code nums} is sorted in place.
     *
     * @param nums   the input values; {@code null} or fewer than four elements yields an empty list
     * @param target the required sum
     * @return the distinct quadruplets, each listed in non-decreasing order
     */
    public List<List<Integer>> fourSum(int[] nums, int target) {
        if (nums == null || nums.length < 4) {
            return new ArrayList<>();
        }
        Set<List<Integer>> unique = new HashSet<>();
        Arrays.sort(nums);
        int n = nums.length;
        for (int i = 0; i < n - 3; i++) {
            for (int j = i + 1; j < n - 2; j++) {
                // Amount the last two elements must add up to; long keeps it exact.
                long remaining = (long) target - nums[i] - nums[j];
                int z = j + 1, k = n - 1;
                while (z < k) {
                    long pairSum = (long) nums[z] + nums[k];
                    if (pairSum == remaining) {
                        unique.add(Arrays.asList(nums[i], nums[j], nums[z], nums[k]));
                        // Any further match that reuses position z or k must reuse the same value,
                        // i.e. it would repeat this quadruplet, so both pointers can advance.
                        z++;
                        k--;
                    } else if (pairSum > remaining) {
                        k--;
                    } else {
                        z++;
                    }
                }
            }
        }
        return new ArrayList<>(unique);
    }

    /**
     * Finds all unique quadruplets in {@code nums} whose sum equals {@code target}.
     *
     * <p>Optimized approach: after sorting, duplicate values are skipped at every level instead
     * of being filtered through a set. Both outer loops are pruned using the smallest and largest
     * sums reachable from the current position. The quadruplets are produced in ascending
     * lexicographic order.
     *
     * <p>Side effect: {@code nums} is sorted in place.
     *
     * @param nums   the input values; {@code null} or fewer than four elements yields an empty list
     * @param target the required sum
     * @return the distinct quadruplets, each listed in non-decreasing order
     */
    public List<List<Integer>> fourSum2(int[] nums, int target) {
        List<List<Integer>> list = new ArrayList<>();
        if (nums == null || nums.length < 4) {
            return list;
        }
        Arrays.sort(nums);
        int n = nums.length;
        for (int i = 0; i < n - 3; i++) {
            // Same first value as the previous iteration would only produce duplicates.
            if (i > 0 && nums[i] == nums[i - 1]) {
                continue;
            }
            // Smallest sum possible from here already exceeds target; later i only get larger.
            if ((long) nums[i] + nums[i + 1] + nums[i + 2] + nums[i + 3] > target) {
                break;
            }
            // Even the three largest values cannot reach target with this nums[i].
            if ((long) nums[i] + nums[n - 3] + nums[n - 2] + nums[n - 1] < target) {
                continue;
            }
            for (int j = i + 1; j < n - 2; j++) {
                if (j > i + 1 && nums[j] == nums[j - 1]) {
                    continue;
                }
                if ((long) nums[i] + nums[j] + nums[j + 1] + nums[j + 2] > target) {
                    break;
                }
                if ((long) nums[i] + nums[j] + nums[n - 2] + nums[n - 1] < target) {
                    continue;
                }
                long tar = (long) target - nums[i] - nums[j];
                int z = j + 1, k = n - 1;
                while (z < k) {
                    long t = (long) nums[z] + nums[k];
                    if (tar == t) {
                        list.add(Arrays.asList(nums[i], nums[j], nums[z], nums[k]));
                        // Step both pointers past every copy of the values just used.
                        while (z < k && nums[z] == nums[z + 1]) {
                            z++;
                        }
                        z++;
                        while (z < k && nums[k] == nums[k - 1]) {
                            k--;
                        }
                        k--;
                    } else if (tar < t) {
                        k--;
                    } else {
                        z++;
                    }
                }
            }
        }
        return list;
    }
}
